import os

# Use JAX backend
os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

# Use only the first GPU (CUDA device index 0)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import math
import tensorflow as tf
import tensorflow_datasets as tfds
import keras
from keras import layers
from keras import activations
import typing_extensions as tx
import typing
from src.efficient_attention import EfficientAttention
from src.standard_attention import StandardMultiHeadAttention
from src.optimised_attention import OptimisedAttention
from src.super_attention import SuperAttention
import jax.numpy as jnp
import jax
from keras.initializers import RandomNormal
from keras.callbacks import Callback
import time
import pandas as pd
import matplotlib.pyplot as plt

# Dataset / model-input configuration
num_classes = 1000
input_shape = (112, 112, 3)
batch_size = 128
dropout_rate = 0.15

# Side length of the square patches the image is cut into.
PATCH_SIZE = 8

# Number of patches per image (patches along axis 0 times patches along axis 1).
NUM_PATCHES = (input_shape[0] // PATCH_SIZE) * (input_shape[1] // PATCH_SIZE)


# Load the ImageNet train and validation splits through TFDS, using
# `cache_dir` as the dataset directory (created if it does not exist yet).
cache_dir = "/mnt/storage/datasets/imagenet"
os.makedirs(cache_dir, exist_ok=True)
train_ds, val_ds = (
    tfds.load("imagenet2012", split=split_name, with_info=False, data_dir=cache_dir)
    for split_name in ("train", "validation")
)

def preprocess_image(image, augment=True):
    """Resize an image to the model input size and scale it to [0, 1].

    Args:
        image: Image tensor as produced by the TFDS `image` feature.
        augment: When True (the default, matching the previous behaviour),
            apply a random horizontal flip. Pass False for evaluation data so
            that validation metrics are deterministic.

    Returns:
        The resized image divided by 255 and clipped to the range [0, 1].
    """
    image = tf.image.resize(image, input_shape[:2])
    if augment:
        image = tf.image.random_flip_left_right(image)
    return tf.clip_by_value(image / 255.0, 0.0, 1.0)


# Training pipeline: (image, label) pairs with augmentation, shuffled with a
# buffer of 10 batches, then batched and prefetched.
train_dataset = (
    train_ds.map(
        lambda x: (preprocess_image(x['image'], augment=True), x['label']),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .shuffle(10 * batch_size)
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)


# Validation pipeline: same preprocessing but WITHOUT random augmentation and
# without shuffling. `drop_remainder=True` discards the final partial batch.
val_dataset = (
    val_ds.map(
        lambda x: (preprocess_image(x['image'], augment=False), x['label']),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .batch(batch_size, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)


# Type alias for an image-size argument: either a pair of ints or a single int.
ImageSizeArg = typing.Union[typing.Tuple[int, int], int]



# Configure the hyperparameters
num_epochs = 20
# Number of batches to accumulate so that batch_size * GRAD_ACCUM_STEPS >= 1024;
# passed to the optimizer as `gradient_accumulation_steps` in run_experiment.
GRAD_ACCUM_STEPS = math.ceil(1024 / batch_size)

# Learning-rate range covered by the exponential decay below.
max_learning_rate = 1.00e-4
min_learning_rate = 3.95e-5

# Learning-rate schedule.
# `lr_step` is the constant per-epoch multiplicative factor, chosen so that
# lr_step ** num_epochs == min_learning_rate / max_learning_rate
# (up to floating-point rounding).
lr_step = 2 ** ((math.log(min_learning_rate, 2) - math.log(max_learning_rate, 2))/num_epochs)
def lr_scheduler(epoch: int, lr: float) -> float:
    """Return the given learning rate multiplied by `lr_step`.

    Used as the schedule function of `keras.callbacks.LearningRateScheduler`
    in `run_experiment`.

    Args:
        epoch: Epoch index supplied by the callback (not used here).
        lr: Learning rate to decay.

    Returns:
        `lr * lr_step`.
    """
    return lr * lr_step



class AppendClassToken(layers.Layer):
    """Layer that adds one learnable class token to a token sequence.

    The token is concatenated in front of the input along axis 1, so the
    output has one more entry on that axis than the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        """Create the class-token weight of shape (1, 1, input_shape[-1])."""
        self.hidden_size = input_shape[-1]
        self.class_token = self.add_weight(
            name='class_token_weight',
            shape=(1, 1, self.hidden_size),
            initializer=RandomNormal(mean=0., stddev=0.02),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, x):
        """Return `x` with the class token placed at index 0 of axis 1."""
        # The static leading dimension of `x` is used as the repeat count, so
        # one copy of the token is produced per batch element.
        repeats = (x.shape[0], 1, 1)
        tokens = jnp.tile(self.class_token, repeats)
        return jnp.concatenate((tokens, x), axis=1)

    def compute_output_shape(self, input_shape):
        """Same as the input shape, with axis 1 grown by one."""
        batch, seq_len = input_shape[0], input_shape[1]
        return (batch, seq_len + 1, self.hidden_size)

    def get_config(self):
        """Return the base layer config (this layer adds no options)."""
        return super().get_config()


class AddPositionEmbs(layers.Layer):
    """Adds a learnable position embedding to a 3-D input.

    The embedding has shape (1, input_shape[1], input_shape[2]) and is
    broadcast over the leading (batch) axis when added to the input.

    Args:
        stddev: Standard deviation of the normal initializer for the
            position-embedding weight. Accepted by the constructor so that
            `from_config(get_config())` round-trips (the config has always
            contained this key).
        **kwargs: Forwarded to `keras.layers.Layer`.
    """

    def __init__(self, *, stddev=0.06, **kwargs):
        super().__init__(**kwargs)
        self.stddev = stddev

    def build(self, input_shape):
        """Create the position-embedding weight for the given input shape."""
        pos_emb_shape = (1, input_shape[1], input_shape[2])
        self.pos_embedding = self.add_weight(
            name='pos_embedding',
            shape=pos_emb_shape,
            initializer=RandomNormal(stddev=self.stddev),
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        """Return `inputs` plus the position embedding."""
        # Broadcasting over the leading axis gives the same result as tiling
        # the embedding to the batch size first, without materialising the
        # copies and without needing a static batch dimension.
        return inputs + self.pos_embedding

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Return the layer config including `stddev`."""
        config = super().get_config()
        config.update({'stddev': self.stddev})
        return config


class TransformerBlock(layers.Layer):
    """Transformer block: self-attention and an MLP, each with a residual.

    Layer normalization is applied to the input of each of the two sub-layers,
    and each sub-layer's output is added back to its un-normalized input.

    Args:
        attn_arch: Callable (e.g. an attention layer class) invoked as
            `attn_arch(key_dim=..., num_heads=...)` to create the attention
            layer. The created layer is called as
            `att(x, x, return_attention_scores=True, training=...)` and must
            return an `(output, attention_scores)` pair.
        num_heads: Number of attention heads; `key_dim` is set to
            `input_shape[-1] // num_heads`.
        mlp_dim: Number of units of the first Dense layer of the MLP.
        dropout: Dropout rate used after attention and inside the MLP.
    """

    def __init__(self, *args, attn_arch, num_heads, mlp_dim, dropout, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.dropout = dropout
        self.attn_arch = attn_arch

    def build(self, input_shape):
        """Create the attention layer, the MLP and the normalization layers."""
        hidden_size = input_shape[-1]
        self.att = self.attn_arch(
            key_dim=hidden_size // self.num_heads,
            num_heads=self.num_heads,
        )
        self.mlpblock = keras.Sequential(
            [
                layers.Dense(
                    self.mlp_dim,
                    activation="linear",
                    name=f"{self.name}-Dense_0",
                ),
                # Exact (non-approximate) GELU applied as a separate layer.
                layers.Lambda(
                    lambda x: activations.gelu(x, approximate=False)
                ),
                layers.Dropout(self.dropout),
                layers.Dense(hidden_size, name=f"{self.name}-Dense_1"),
                layers.Dropout(self.dropout),
            ],
            name="MlpBlock_3",
        )
        self.layernorm1 = layers.LayerNormalization(
            epsilon=1e-6, name="LayerNorm_0"
        )
        self.layernorm2 = layers.LayerNormalization(
            epsilon=1e-6, name="LayerNorm_2"
        )
        self.dropout_layer = layers.Dropout(self.dropout)

    def call(self, inputs, training=None):
        """Run the block.

        Args:
            inputs: Input sequence tensor.
            training: Training-mode flag forwarded to the attention layer,
                the dropout layer and the MLP.

        Returns:
            A pair `(outputs, attention_scores)` where `outputs` is the
            block output and `attention_scores` is whatever the attention
            layer returns as its second value.
        """
        x = self.layernorm1(inputs)
        x, weights = self.att(x, x, return_attention_scores=True, training=training)
        x = self.dropout_layer(x, training=training)
        x = x + inputs
        y = self.layernorm2(x)
        y = self.mlpblock(y, training=training)
        return x + y, weights

    def get_config(self):
        """Return the layer config.

        NOTE: `attn_arch` is not part of the config, so `from_config` on the
        returned dict fails unless the caller adds an `attn_arch` entry.
        """
        config = super().get_config()
        config.update(
            {
                "num_heads": self.num_heads,
                "mlp_dim": self.mlp_dim,
                "dropout": self.dropout,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Create a block from `config`, which must provide `attn_arch`."""
        return cls(**config)


def build_model(
    attn_arch,
    image_size: ImageSizeArg,
    patch_size: int,
    num_layers: int,
    hidden_size: int,
    num_heads: int,
    name: str,
    mlp_dim: int,
    classes: int,
    dropout=0.1,
    activation="linear",
):
    """Build a ViT-style image classifier as a Keras functional model.

    The image is cut into non-overlapping patches by a strided convolution,
    a class token and position embeddings are added, the sequence is passed
    through `num_layers` `TransformerBlock`s, max-pooled over the whole
    sequence and classified by a two-layer head.

    Args:
        attn_arch: Attention layer class forwarded to every
            `TransformerBlock`.
        image_size: Size of the input images, either a pair of ints or a
            single int (used for both spatial dimensions). Each dimension
            must be a multiple of `patch_size`.
        patch_size: Side length of each square patch.
        num_layers: Number of transformer blocks.
        hidden_size: Number of filters of the patch-embedding convolution,
            i.e. the token width.
        num_heads: Number of attention heads per block.
        name: Name of the returned model.
        mlp_dim: Number of hidden units of the MLP inside each block.
        classes: Number of output classes.
        dropout: Dropout rate used in the blocks and in the head.
        activation: Activation of the final Dense layer.

    Returns:
        A `keras.Model` mapping images to `classes` outputs. The model input
        has a fixed batch dimension equal to the module-level `batch_size`.
    """
    if isinstance(image_size, int):
        image_size = (image_size, image_size)
    assert (image_size[0] % patch_size == 0) and (
        image_size[1] % patch_size == 0
    ), "image_size must be a multiple of patch_size"

    # Sequence length seen by the transformer: one token per patch plus the
    # class token. Derived from the arguments so that the pooling window
    # stays correct for any image_size / patch_size combination.
    seq_len = (image_size[0] // patch_size) * (image_size[1] // patch_size) + 1

    # The batch size is fixed because AppendClassToken reads the static
    # leading dimension of its input.
    x = layers.Input(shape=(image_size[0], image_size[1], 3), batch_size=batch_size)
    y = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid",
        name="embedding",
    )(x)
    y = layers.Reshape((y.shape[1] * y.shape[2], hidden_size))(y)
    y = AppendClassToken(name="class_token")(y)
    y = AddPositionEmbs(name="Transformer-posembed_input")(y)
    for _ in range(num_layers):
        # NOTE(review): `training=True` is passed at graph-construction time;
        # confirm that Keras replaces it with the model-level training flag
        # when the model runs in evaluate()/predict().
        y, _ = TransformerBlock(
            attn_arch=attn_arch,
            num_heads=num_heads,
            mlp_dim=mlp_dim,
            dropout=dropout,
        )(y, training=True)

    y = layers.LayerNormalization(
        epsilon=1e-6, name="Transformer-encoder_norm"
    )(y)
    # Max over the full sequence (all patch tokens and the class token).
    y = layers.MaxPooling1D(pool_size=seq_len)(y)
    y = layers.Flatten()(y)
    y = layers.Dense(2 * classes, name="pre_head", activation="gelu")(y)
    y = layers.Dropout(dropout)(y)
    y = layers.Dense(classes, name="head", activation=activation)(y)
    # Use the same `keras` package the layers come from rather than `tf.keras`.
    return keras.Model(inputs=x, outputs=y, name=name)



def vit_b16(
    attn_arch,
    image_size: ImageSizeArg = input_shape[:2],
    classes=1000,
    num_heads=12,
    num_layers=8,
    hidden_size=768,
    mlp_dim=3072,
    dropout=0.1,
    activation="linear",
):
    """Build the ViT model used in the experiments.

    Thin wrapper around `build_model` that fixes the patch size to the
    module-level `PATCH_SIZE` and the model name to "vit-b16"; every other
    argument is forwarded unchanged.

    Args:
        attn_arch: Attention layer class used in every transformer block.
        image_size: Size of the input images.
        classes: Number of output classes.
        num_heads: Number of attention heads per block.
        num_layers: Number of transformer blocks.
        hidden_size: Token width.
        mlp_dim: Number of hidden units of the MLP in each block.
        dropout: Dropout rate.
        activation: Activation of the final layer.

    Returns:
        The model returned by `build_model`.
    """
    return build_model(
        attn_arch=attn_arch,
        name="vit-b16",
        image_size=image_size,
        patch_size=PATCH_SIZE,
        classes=classes,
        num_layers=num_layers,
        num_heads=num_heads,
        hidden_size=hidden_size,
        mlp_dim=mlp_dim,
        dropout=dropout,
        activation=activation,
    )

class TimeHistory(Callback):
    """Callback that records how long each training epoch takes.

    After training, `times` holds one duration in seconds per completed
    epoch, in order.
    """

    def on_train_begin(self, logs=None):
        # Start from an empty list every time training begins.
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.time() - self.epoch_start_time)






# Compile, train, and evaluate the model
def run_experiment(model, arch_name="StandardMultiHeadAttention", run_number=1, number_of_heads=1):
    """Compile, train and evaluate `model`, writing results under ./results/imagenet.

    The model is trained on `train_dataset` with AdamW and the exponential
    learning-rate schedule, validated on `val_dataset` after every epoch, and
    the weights with the best `val_accuracy` are checkpointed. Afterwards the
    best weights are reloaded and evaluated once more on `val_dataset`; that
    result is what is stored as "test" history.

    Files written (tab-separated CSV unless noted), all below
    `<arch_name>_H_<number_of_heads>/run_num_<run_number>/`:
        model/.../model.weights.h5   best weights (HDF5)
        history/.../history.csv      per-epoch training history
        history/.../test_history.csv final loss / accuracy / top-5 accuracy
        history/.../general_info.csv attention parameter count, epoch times

    Args:
        model: Uncompiled Keras model to train.
        arch_name: Name of the attention architecture, used in output paths.
        run_number: Index of the run, used in output paths.
        number_of_heads: Number of attention heads, used in output paths.

    Returns:
        The `History` object returned by `model.fit`.
    """
    optimizer = keras.optimizers.AdamW(
        learning_rate=max_learning_rate,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # All outputs of one run share the same "<arch>_H_<heads>/run_num_<n>" suffix.
    run_tag = f"{arch_name}_H_{number_of_heads}/run_num_{run_number}"
    checkpoint_filepath = f"./results/imagenet/model/{run_tag}/model.weights.h5"
    history_dir = f"./results/imagenet/history/{run_tag}"
    history_filepath = f"{history_dir}/history.csv"
    test_history_filepath = f"{history_dir}/test_history.csv"
    general_info_filepath = f"{history_dir}/general_info.csv"
    os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)
    os.makedirs(history_dir, exist_ok=True)

    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    time_callback = TimeHistory()
    lr_scheduler_callback = keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=1)

    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=num_epochs,
        callbacks=[checkpoint_callback, time_callback, lr_scheduler_callback],
    )

    # Per-epoch history, rounded to 5 decimal places.
    pd.DataFrame(history.history).round(5).to_csv(history_filepath, sep='\t', index=False)

    # Evaluate the best checkpoint (by val_accuracy) rather than the last epoch.
    model.load_weights(checkpoint_filepath)
    loss, accuracy, top_5_accuracy = model.evaluate(val_dataset)
    test_history_df = pd.DataFrame(
        [[round(loss, 5), round(accuracy, 5), round(top_5_accuracy, 5)]],
        columns=["loss", "accuracy", "top-5-accuracy"],
    )
    test_history_df.to_csv(test_history_filepath, sep='\t', index=False)

    # Parameter count of the attention layer of the first layer that exposes
    # an `att` attribute (the first transformer block), looked up by
    # attribute instead of by a hard-coded position in `model.layers`.
    first_block = next(
        (layer for layer in model.layers if hasattr(layer, "att")), None
    )
    num_of_attention_params = (
        first_block.att.count_params() if first_block is not None else 0
    )

    # Average epoch duration, with and without the first epoch. Either
    # average is NaN when there are no epochs to average over, instead of
    # raising ZeroDivisionError.
    epoch_times = time_callback.times
    later_epoch_times = epoch_times[1:]
    average_epoch_time = (
        sum(epoch_times) / len(epoch_times) if epoch_times else float("nan")
    )
    average_epoch_time_without_first_epoch = (
        sum(later_epoch_times) / len(later_epoch_times)
        if later_epoch_times
        else float("nan")
    )

    # One header row and one value row; times rounded to 3 decimal places.
    general_info_pd = pd.DataFrame(
        [[
            num_of_attention_params,
            round(average_epoch_time_without_first_epoch, 3),
            round(average_epoch_time, 3),
        ]],
        columns=[
            "num_of_attention_params",
            "average_epoch_time_excluding_first_epoch",
            "average_epoch_time",
        ],
    )
    general_info_pd.to_csv(general_info_filepath, sep='\t', index=False)

    return history



def plot_history(item, history, model_name="StandardMultiHeadAttention", run_number=0, number_of_heads=1):
    """Plot the training and validation curves of one metric and save them as PDF.

    The figure is written to
    `./results/imagenet/plots/<model_name>_H_<number_of_heads>/run_num_<run_number>/history_<item>.pdf`
    and then shown. The metric name is part of the file name so that plots
    of different metrics from the same run do not overwrite each other.

    Args:
        item: Key of the metric in `history.history`; the validation curve
            is read from the key `"val_" + item`.
        history: Object with a `history` dict of per-epoch metric lists, as
            returned by `model.fit`.
        model_name: Architecture name used in the title and output path.
        run_number: Run index used in the title and output path.
        number_of_heads: Number of heads used in the title and output path.
    """
    title = "Train and Validation {} for {} with {} heads in run {}".format(
        item, model_name, number_of_heads, run_number
    )

    # A dedicated figure per call keeps curves from earlier calls out of
    # this plot.
    fig, ax = plt.subplots()
    ax.plot(history.history[item], label=item)
    ax.plot(history.history["val_" + item], label="val_" + item)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(item)
    ax.set_title(title, fontsize=14)
    ax.legend()
    ax.grid()
    fig.tight_layout()

    plot_filepath = (
        "./results/imagenet/plots/" + model_name + f"_H_{number_of_heads}"
        + "/run_num_" + str(run_number) + f"/history_{item}.pdf"
    )
    os.makedirs(os.path.dirname(plot_filepath), exist_ok=True)
    # Save BEFORE showing: once show() returns the figure may already have
    # been closed, which would produce an empty file.
    fig.savefig(plot_filepath)
    plt.show()
    plt.close(fig)



# Attention architectures and head counts to compare

ATTENTION_ARCHS = [SuperAttention, StandardMultiHeadAttention, OptimisedAttention, EfficientAttention]
NUM_OF_HEADS = [12]

NUM_OF_RUNS = 1
# Train every architecture with every head count in every run. Nested loops
# are used instead of zip(), which would stop at the shorter list and run
# only the first architecture.
for run_number in range(NUM_OF_RUNS):
    for attention_arch in ATTENTION_ARCHS:
        for num_of_heads in NUM_OF_HEADS:
            vit_classifier = vit_b16(
                attention_arch,
                num_heads=num_of_heads, dropout=dropout_rate)
            vit_classifier.summary()
            history = run_experiment(
                vit_classifier,
                attention_arch.__name__,
                run_number,
                number_of_heads=num_of_heads,
            )
            # Pass the head count so plots land next to the matching results
            # instead of under the default "_H_1" directory.
            for metric_name in ("loss", "accuracy", "top-5-accuracy"):
                plot_history(
                    metric_name,
                    history,
                    attention_arch.__name__,
                    run_number,
                    number_of_heads=num_of_heads,
                )
